{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gYY_G2uoEdfL"
      },
      "source": [
        "Copyright 2023 Google LLC.\n",
        "\n",
        "SPDX-License-Identifier: Apache-2.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CgmpKemkFJfx"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
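    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fpPNFEWT3afm"
      },
      "source": [
        "# Instructions\n",
        "\n",
        "This notebook is meant to be run against a local Colab runtime. To connect to one, follow steps 2 - 4 of the [Colab local runtimes guide](https://research.google.com/colaboratory/local-runtimes.html)."
      ]
    },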
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rwE_hMItvKIk"
      },
      "outputs": [],
      "source": [
        "# Shell out to nvidia-smi to check which GPU(s) the connected runtime can see.\n",
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SIPcmjCFSeOm"
      },
      "source": [
        "# Notebook-specific installs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0PL58CHSSewD"
      },
      "outputs": [],
      "source": [
        "# Extra packages this notebook needs on top of the runtime's defaults.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install opencv-python h5py tensorboard\n",
        "# Stay on moviepy 1.x: the `moviepy.editor` module and `clip.ipython_display`,\n",
        "# both used further down in this notebook, were removed in moviepy 2.0.\n",
        "%pip install --upgrade \"moviepy\u003c2\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1oT6Holf-rDx"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import flax.linen as nn\n",
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import tqdm\n",
        "\n",
        "import chex\n",
        "import pickle\n",
        "\n",
        "import datetime\n",
        "import flax.jax_utils as flax_utils\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from importlib import reload"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2FswgGpr_NWo"
      },
      "outputs": [],
      "source": [
        "# Local device count; passed to `next_pmapped_batch` further down.\n",
        "num_devices = jax.local_device_count()\n",
        "# Last expression of the cell, so the device list is shown as the cell output.\n",
        "jax.devices()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wJIhe7E7_TFX"
      },
      "outputs": [],
      "source": [
        "from hct import ndp\n",
        "from hct.common import utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4mMq1JpKIhls"
      },
      "outputs": [],
      "source": [
        "import h5py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ps5BILIiI3zs"
      },
      "outputs": [],
      "source": [
        "# The episode file has to be downloaded into the working directory beforehand.\n",
        "# Print the name of every object stored in it.\n",
        "with h5py.File('episode_1.hdf5', \"r\") as episode_file:\n",
        "  episode_file.visit(print)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rlKAHP2M6HIj"
      },
      "outputs": [],
      "source": [
        "# Copy the `action`, `observations/images/top`, `observations/qpos` and\n",
        "# `observations/qvel` datasets into numpy arrays before the file is closed.\n",
        "with h5py.File('episode_1.hdf5', \"r\") as episode_file:\n",
        "  observations = episode_file['observations']\n",
        "  actions = np.array(episode_file['action'])\n",
        "  images = np.array(observations['images']['top'])\n",
        "  qpos = np.array(observations['qpos'])\n",
        "  qvel = np.array(observations['qvel'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QnyMxzEWJJ0w"
      },
      "outputs": [],
      "source": [
        "# Print the shapes of the loaded arrays as a quick sanity check.\n",
        "episode_arrays = (actions, images, qpos, qvel)\n",
        "print(*[arr.shape for arr in episode_arrays])\n",
        "# Size of axis 1 of `actions`; reused by the plotting cells below.\n",
        "action_dim = actions.shape[1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tfXpywcKJwSK"
      },
      "outputs": [],
      "source": [
        "from moviepy.editor import ImageSequenceClip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wsPdBeGLJwox"
      },
      "outputs": [],
      "source": [
        "# Play the loaded `top` image frames back as an inline video.\n",
        "playback_fps = 50\n",
        "frames = list(images)\n",
        "clip = ImageSequenceClip(frames, fps=playback_fps)\n",
        "clip.ipython_display(fps=playback_fps)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MSoP-MTJYjxK"
      },
      "source": [
        "## Pre-process data: raw_data --\u003e raw_norm_data --\u003e hct_data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RkdG8NAACY8D"
      },
      "outputs": [],
      "source": [
        "import cv2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5FEzbtnvrK2z"
      },
      "outputs": [],
      "source": [
        "# Resize images \u0026 normalize all data\n",
        "target_size = (320, 240)  # cv2.resize takes dsize as (width, height)\n",
        "images_resized = [cv2.resize(frame, target_size) for frame in images]\n",
        "\n",
        "# One observation row per step: qpos followed by qvel.\n",
        "hf_obs = np.concatenate((qpos, qvel), axis=1)\n",
        "\n",
        "raw_data = dict(images=np.array(images_resized), hf_obs=hf_obs, actions=actions)\n",
        "\n",
        "# `norm_stats` stays in scope: it is needed again to unnormalize the predictions.\n",
        "norm_stats = utils.compute_norm_stats(raw_data)\n",
        "raw_norm_data = utils.normalize(raw_data, *norm_stats)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BECb6E4hRq55"
      },
      "outputs": [],
      "source": [
        "\"\"\"Batch data into following format:\n",
        "for t = 0, 1, ....\n",
        "example = {\n",
        "  'image': s_t,\n",
        "  'actions': U_t: (num_actions, action_dim) = [u_t(tau_0), ..., u_t(tau_{M-1})]\n",
        "  'hf_obs': X_t: (num_actions, num_hf_obs_per_action+1, x_dim) = [x_t^0, ..., x_t^{M-1}]\n",
        "}\n",
        "where X_t[0][-1] = x_t(0), coinciding with s_t and U_t[0]\n",
        "\"\"\"\n",
        "norm_images, norm_hf_obs, norm_actions = (\n",
        "    raw_norm_data[name] for name in ('images', 'hf_obs', 'actions'))\n",
        "\n",
        "# Chunking constants\n",
        "num_actions = 5  # actions per example\n",
        "num_hf_obs_per_action = 1  # current dataset has action-freq = hf_state-freq\n",
        "\n",
        "# Prepend copies of the first observation so that the very first window is full.\n",
        "init_tile = np.array([norm_hf_obs[0]]*num_hf_obs_per_action)\n",
        "norm_hf_obs = np.concatenate((init_tile, norm_hf_obs), axis=0)\n",
        "\n",
        "hct_data = {name: [] for name in ('images', 'hf_obs', 'actions')}\n",
        "\n",
        "# Step through control actions (can also step through images if that is the base frequency)\n",
        "final_idx = len(norm_actions) - num_actions\n",
        "for chunk_start in range(0, final_idx + 1, num_actions):\n",
        "  chunk_stop = chunk_start + num_actions\n",
        "  hct_data['images'].append(norm_images[chunk_start])  # current dataset has image-freq = control-freq\n",
        "  hct_data['actions'].append(norm_actions[chunk_start:chunk_stop, ...])\n",
        "\n",
        "  # In the padded array, the window belonging to control step `step` ends\n",
        "  # (inclusively) at index (step + 1) * num_hf_obs_per_action and reaches\n",
        "  # num_hf_obs_per_action entries back from there.\n",
        "  state_obs = []\n",
        "  for step in range(chunk_start, chunk_stop):\n",
        "    window_end = (step + 1) * num_hf_obs_per_action\n",
        "    state_obs.append(\n",
        "        norm_hf_obs[window_end - num_hf_obs_per_action:window_end + 1, ...])\n",
        "\n",
        "  hct_data['hf_obs'].append(np.stack(state_obs, axis=0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EtsQgQuCCoXq"
      },
      "outputs": [],
      "source": [
        "# Turn each list of per-example arrays into one array and show the shapes.\n",
        "hct_data = {name: np.array(examples) for name, examples in hct_data.items()}\n",
        "shape_summary = [(name, stacked.shape) for name, stacked in hct_data.items()]\n",
        "print(shape_summary)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2N56TrJu4Z75"
      },
      "outputs": [],
      "source": [
        "# Plot every action dimension in its own panel; each chunk is drawn as a separate\n",
        "# line over the step indices it covers.\n",
        "num_plot_cols = 2\n",
        "# Size the grid from action_dim (7 x 2 for 14 dimensions) instead of hard-coding\n",
        "# it, so that another action dimension cannot index past the last panel.\n",
        "num_plot_rows = max(1, -(-action_dim // num_plot_cols))\n",
        "# squeeze=False keeps `axs` two-dimensional even when there is a single row.\n",
        "fig, axs = plt.subplots(num_plot_rows, num_plot_cols, figsize=(10, 20),\n",
        "                        squeeze=False)\n",
        "\n",
        "for i in range(action_dim):\n",
        "  ax = axs[i // num_plot_cols][i % num_plot_cols]\n",
        "  ax.set_title(f'action dim {i}')\n",
        "  for t, chunk in enumerate(hct_data['actions']):\n",
        "    steps = np.arange(t*num_actions, (t+1)*num_actions)\n",
        "    _ = ax.plot(steps, chunk[:, i])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qkyEzwS9CEGB"
      },
      "source": [
        "## Setup data-loader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XMXBrNHtgzvO"
      },
      "outputs": [],
      "source": [
        "# For NDP - only use x_t(0)\n",
        "# The slice overwrites hct_data['hf_obs'], so it is applied only while the\n",
        "# per-action window axes are still there; running this cell a second time is\n",
        "# then a no-op instead of an IndexError.\n",
        "# NOTE(review): assumes every observation x is a flat vector, i.e. the unsliced\n",
        "# array is (num_examples, num_actions, num_hf_obs_per_action+1, x_dim) - confirm.\n",
        "if hct_data['hf_obs'].ndim \u003e 2:\n",
        "  hct_data['hf_obs'] = hct_data['hf_obs'][:, 0, -1, ...]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KI4-giZvoxhV"
      },
      "outputs": [],
      "source": [
        "# Shuffle the examples once with a fixed seed, then split them into train / eval.\n",
        "data_prng = hk.PRNGSequence(jax.random.PRNGKey(654321))\n",
        "num_data = len(hct_data['images'])\n",
        "shuffle_idxs = jax.random.permutation(next(data_prng), num_data)\n",
        "train_ratio = 0.8\n",
        "num_train = int(train_ratio * num_data)\n",
        "train_idxs = shuffle_idxs[:num_train]\n",
        "eval_idxs = shuffle_idxs[num_train:]\n",
        "\n",
        "training_data = {name: arr[train_idxs] for name, arr in hct_data.items()}\n",
        "eval_data = {name: arr[eval_idxs] for name, arr in hct_data.items()}\n",
        "\n",
        "# The train manager gets half the training-set size as its batch size; the eval\n",
        "# manager is given the full eval-set size.\n",
        "train_batch_size = int(0.5*num_train)\n",
        "train_data_manager = utils.BatchManager(\n",
        "    next(data_prng), training_data, batch_size=train_batch_size)\n",
        "\n",
        "eval_set_size = len(eval_data['images'])\n",
        "eval_data_manager = utils.BatchManager(\n",
        "    next(data_prng), eval_data, eval_set_size)\n",
        "\n",
        "# Reused below when the train state is created.\n",
        "sample_batch = train_data_manager.next_batch()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JJs2m0txB55L"
      },
      "source": [
        "## Setup Model and Trainstate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KpDk8ox6CIhp"
      },
      "outputs": [],
      "source": [
        "model_prng = hk.PRNGSequence(jax.random.PRNGKey(123456))\n",
        "\n",
        "\n",
        "def loss(u_true, u_pred):\n",
        "  \"\"\"Returns the sum of squared differences between `u_true` and `u_pred`.\"\"\"\n",
        "  return jnp.sum(jnp.square(u_true - u_pred))\n",
        "\n",
        "\n",
        "# Taken from the last axis of the batched actions.\n",
        "action_dim = hct_data['actions'].shape[-1]\n",
        "\n",
        "ndp_config = dict(\n",
        "    action_dim=action_dim,\n",
        "    num_actions=num_actions,\n",
        "    loss_fnc=loss,\n",
        "    activation=nn.relu,\n",
        "    zs_dim=32,\n",
        "    zs_width=64,\n",
        "    zo_dim=16,\n",
        "    num_basis_fncs=5,\n",
        ")\n",
        "model = ndp.NDP(**ndp_config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "my8S15HhKwYl"
      },
      "outputs": [],
      "source": [
        "# Optimizer settings\n",
        "learning_rate = 1e-2\n",
        "weight_decay = 0.\n",
        "\n",
        "# The sample batch's images and hf_obs are passed along with the model and a PRNG key.\n",
        "init_inputs = (sample_batch['images'], sample_batch['hf_obs'])\n",
        "train_state = ndp.create_ndp_train_state(\n",
        "    model, next(model_prng), learning_rate, weight_decay, *init_inputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1YtImZovMcrw"
      },
      "outputs": [],
      "source": [
        "# Show the result of `utils.param_count` for the train state's params; this runs\n",
        "# before the state is replicated across devices.\n",
        "utils.param_count(train_state.params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s5M52J_3MjXf"
      },
      "outputs": [],
      "source": [
        "# Replicate model across devices\n",
        "# NOTE(review): this rebinds `train_state` to its replicated version, so running\n",
        "# the cell a second time would replicate an already replicated state. Re-create\n",
        "# the train state first if this cell has to be re-run.\n",
        "train_state = flax_utils.replicate(train_state)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ahmSEt2B9UJ"
      },
      "source": [
        "## Setup Eval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UFKZUHvaCHoC"
      },
      "outputs": [],
      "source": [
        "def eval(eval_batch_manager: utils.BatchManager,\n",
        "         ts: utils.TrainStateBN,\n",
        "         key: chex.PRNGKey,\n",
        "         num_devices: int) -\u003e float:\n",
        "  \"\"\"Averages the loss over `eval_batch_manager.num_batches` eval batches.\n",
        "\n",
        "  Args:\n",
        "    eval_batch_manager: Provides the batches through `next_pmapped_batch`.\n",
        "    ts: Train state handed to `ndp.optimize_ndp`.\n",
        "    key: PRNG key. Not used at present; kept so that existing calls still work.\n",
        "    num_devices: Forwarded to `next_pmapped_batch`.\n",
        "\n",
        "  Returns:\n",
        "    The mean over batches of entry 0 of each batch loss.\n",
        "  \"\"\"\n",
        "  del key  # Unused: nothing below draws random numbers.\n",
        "\n",
        "  num_eval_batches = eval_batch_manager.num_batches\n",
        "\n",
        "  eval_loss = 0.\n",
        "  for _ in range(num_eval_batches):\n",
        "    eval_batch = eval_batch_manager.next_pmapped_batch(num_devices)\n",
        "    # NOTE(review): the loss comes from the same `ndp.optimize_ndp` call that the\n",
        "    # training loop uses. The train state it returns is discarded here - confirm\n",
        "    # that this is the intended way to compute an eval loss.\n",
        "    batch_loss, _ = ndp.optimize_ndp(\n",
        "        ts, eval_batch['images'], eval_batch['hf_obs'], eval_batch['actions'])\n",
        "    # Entry 0 of the returned loss is used, as in the training loop's logging.\n",
        "    eval_loss += batch_loss[0]\n",
        "\n",
        "  return eval_loss / num_eval_batches\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vY9dthBMCJPy"
      },
      "source": [
        "## Create Save dirs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BmIAGVJBN-u1"
      },
      "outputs": [],
      "source": [
        "# Every run gets its own timestamped experiment directory under /tmp.\n",
        "timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')\n",
        "# Join the parts as path components: plain string concatenation would glue the\n",
        "# timestamp onto '/tmp' without a separator, giving a directory at the\n",
        "# filesystem root.\n",
        "exp_dir = os.path.join('/tmp', timestamp)\n",
        "chk_subdir = 'ndp'\n",
        "chk_dir = os.path.join(exp_dir, chk_subdir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bwhzad6kCRDo"
      },
      "source": [
        "## Do Optimization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6bJYIDFEfHh3"
      },
      "outputs": [],
      "source": [
        "num_train_steps = 300\n",
        "log_every = 5\n",
        "eval_every = 10\n",
        "save_every = 50\n",
        "\n",
        "for idx in tqdm.tqdm(range(num_train_steps)):\n",
        "  # One optimization step on the next training batch.\n",
        "  train_batch = train_data_manager.next_pmapped_batch(num_devices)\n",
        "  batch_loss, train_state = ndp.optimize_ndp(\n",
        "      train_state,\n",
        "      train_batch['images'],\n",
        "      train_batch['hf_obs'],\n",
        "      train_batch['actions'])\n",
        "\n",
        "  should_log = idx % log_every == 0\n",
        "  should_eval = idx % eval_every == 0\n",
        "  steps_done = idx + 1\n",
        "\n",
        "  if should_log:\n",
        "    print('idx: ', idx, 'train_loss:', batch_loss[0])\n",
        "\n",
        "  if should_eval:\n",
        "    eval_loss = eval(eval_data_manager, train_state, next(data_prng), num_devices)\n",
        "    print('idx: ', idx, 'eval_loss:', eval_loss)\n",
        "\n",
        "  # Saving is keyed on idx + 1, i.e. it happens after every save_every-th step.\n",
        "  if steps_done % save_every == 0:\n",
        "    utils.save_model(chk_dir, steps_done, save_every,\n",
        "                     flax_utils.unreplicate(train_state))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UoxFESd2T8IA"
      },
      "source": [
        "## Test eval on full episode"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h3gPHczjaaNu"
      },
      "outputs": [],
      "source": [
        "# Skip if proceeding directly from training\n",
        "# Unreplicate the current state, restore from `chk_dir`, then replicate again.\n",
        "restore_template = flax_utils.unreplicate(train_state)\n",
        "restored_state = utils.restore_model(chk_dir, restore_template)\n",
        "train_state = flax_utils.replicate(restored_state)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L7zrijgoUqZm"
      },
      "outputs": [],
      "source": [
        "# Unreplicate once and collect the variables passed to `model.apply` below.\n",
        "single_state = flax_utils.unreplicate(train_state)\n",
        "model_params = dict(params=single_state.params,\n",
        "                    batch_stats=single_state.batch_stats)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9agGXLDCT9yv"
      },
      "outputs": [],
      "source": [
        "# Generate some predictions\n",
        "# All chunks of the episode (train and eval alike) go through the model at once.\n",
        "full_episode_inputs = (\n",
        "    hct_data['images'], hct_data['hf_obs'], hct_data['actions'])\n",
        "actions_pred, losses = model.apply(\n",
        "    model_params, *full_episode_inputs, method=model.compute_augmented_flow)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Urlg4FXK0-z1"
      },
      "outputs": [],
      "source": [
        "# Unnormalize the predictions\n",
        "# Map over the leading axis; the two action statistics are shared by all entries.\n",
        "unnormalize_batch = jax.vmap(utils.unnormalize, in_axes=(0, None, None))\n",
        "action_stats = (norm_stats[0]['actions'], norm_stats[1]['actions'])\n",
        "\n",
        "u_true = unnormalize_batch(hct_data['actions'], *action_stats)\n",
        "u_pred = unnormalize_batch(actions_pred, *action_stats)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ItE2YwBVVHBW"
      },
      "outputs": [],
      "source": [
        "# Plot comparison: dashed = ground truth, solid = prediction; both lines of a\n",
        "# chunk share one color.\n",
        "\n",
        "eval_range = [0, 20]\n",
        "# Clamp the upper end to the number of chunks so that a short episode does not\n",
        "# index past the last chunk.\n",
        "last_chunk = min(eval_range[1], len(u_true))\n",
        "\n",
        "num_plot_cols = 2\n",
        "# Grid sized from action_dim (7 x 2 for 14 dimensions) rather than hard-coded.\n",
        "num_plot_rows = max(1, -(-action_dim // num_plot_cols))\n",
        "# squeeze=False keeps `axs` two-dimensional even when there is a single row.\n",
        "fig, axs = plt.subplots(num_plot_rows, num_plot_cols, figsize=(30, 30),\n",
        "                        squeeze=False)\n",
        "for i in range(action_dim):\n",
        "  ax = axs[i // num_plot_cols][i % num_plot_cols]\n",
        "  ax.set_title(f'action dim {i}')\n",
        "  for t in range(eval_range[0], last_chunk):\n",
        "    steps = np.arange(t*num_actions, (t+1)*num_actions)\n",
        "    lplot = ax.plot(steps, u_true[t, :, i], '--', linewidth=2)\n",
        "    # Reuse the color picked for the ground-truth line for the prediction.\n",
        "    color = lplot[0].get_color()\n",
        "\n",
        "    _ = ax.plot(steps, u_pred[t, :, i], '-', linewidth=1.5, color=color)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TaTmCXl-_n4C"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "16L1qsmG0B-HQWZYn_A-iIXDP6iz2YKCG",
          "timestamp": 1680817019409
        },
        {
          "file_id": "1cOK4xh3_IshDQz6GjMDG3prg7b9Xdb6G",
          "timestamp": 1679955209952
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
